import torch
import torch.nn as nn
from transformers import BertPreTrainedModel, BertModel, BertConfig
from torchcrf import CRF
from .module import IntentClassifier


class ClsBERT(BertPreTrainedModel):
    """BERT encoder with a sentence-level intent classification head.

    The pooled output of ``BertModel`` (second element of its result) is fed
    to an ``IntentClassifier`` that emits one logit per intent label.
    """

    def __init__(self, config, args, intent_label_lst):
        """Build the encoder and the intent head.

        Args:
            config: BERT configuration; ``config.hidden_size`` is used as the
                classifier input width.
            args: Run-time arguments; only ``args.dropout_rate`` is read here.
            intent_label_lst: Collection of intent labels; its length defines
                the number of classifier outputs.
        """
        super().__init__(config)
        self.args = args
        self.num_intent_labels = len(intent_label_lst)
        # Built from ``config`` only: this line does not load any checkpoint.
        # Pretrained weights are presumably supplied by the caller, e.g. via
        # ``ClsBERT.from_pretrained(...)`` -- confirm against the training script.
        self.bert = BertModel(config=config)
        self.intent_classifier = IntentClassifier(
            config.hidden_size, self.num_intent_labels, args.dropout_rate
        )

    def forward(self, input_ids, attention_mask, token_type_ids, intent_label_ids=None):
        """Run the encoder and the intent head, optionally computing the loss.

        Args:
            input_ids: Token ids passed positionally to ``BertModel``.
            attention_mask: Attention mask forwarded to ``BertModel``.
            token_type_ids: Segment ids forwarded to ``BertModel``.
            intent_label_ids: Gold intent labels, or ``None`` (the default) to
                skip the loss computation, e.g. at inference time.

        Returns:
            A tuple ``(intent_logits, *extras)`` where ``extras`` is whatever
            the encoder returned after its first two outputs (hidden states
            and/or attentions when the config enables them). When
            ``intent_label_ids`` is given, the loss is prepended:
            ``(intent_loss, intent_logits, *extras)``.
        """
        encoder_outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        # Index 0 is the per-token sequence output (not needed for intent
        # classification); index 1 is the pooled sentence representation.
        pooled_output = encoder_outputs[1]
        intent_logits = self.intent_classifier(pooled_output)

        outputs = (intent_logits,) + encoder_outputs[2:]
        if intent_label_ids is None:
            return outputs

        if self.num_intent_labels == 1:
            # A single output unit is treated as regression. MSELoss needs the
            # target in the same floating dtype as the prediction, while label
            # tensors are commonly integer-typed, so cast explicitly (this is
            # a no-op when the dtype already matches).
            intent_loss_fct = nn.MSELoss()
            intent_loss = intent_loss_fct(
                intent_logits.view(-1),
                intent_label_ids.view(-1).to(intent_logits.dtype),
            )
        else:
            # Single-label, multi-class classification: one class index per
            # example after flattening.
            intent_loss_fct = nn.CrossEntropyLoss()
            intent_loss = intent_loss_fct(
                intent_logits.view(-1, self.num_intent_labels),
                intent_label_ids.view(-1),
            )

        return (intent_loss,) + outputs
